# Most of this code is from https://github.com/ultmaster/neuralpredictor.pytorch 
# which was authored by Yuge Zhang, 2020

import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize_adj(adj):
    """Row-normalize a matrix (or batch of matrices) so each row sums to 1.

    Each entry is divided by the sum of its row, i.e. the sum over the last
    dimension. Any number of leading (batch) dimensions is supported. The
    per-row sums are broadcast across the row instead of being materialized
    with ``repeat``.

    Args:
        adj: Tensor whose last dimension indexes matrix columns, e.g.
            ``(N, N)`` or ``(batch, N, N)``.

    Returns:
        Tensor with the same shape as ``adj``. A row whose entries sum to
        zero produces non-finite values (division by zero). Callers should
        ensure every row has a non-zero sum, e.g. by adding self-loops to a
        non-negative adjacency matrix first.
    """
    # keepdim=True leaves a trailing size-1 axis, so the division broadcasts
    # the row sum across every column of that row.
    rowsum = adj.sum(dim=-1, keepdim=True)
    return torch.div(adj, rowsum)


class MLP(nn.Module):
    """Three-layer perceptron regressor over a flattened graph encoding.

    The row-normalized adjacency matrix (with self-loops added) and the
    operation features are each flattened per sample and concatenated into a
    single vector. That vector passes through two ReLU hidden layers, each
    followed by dropout, and a final linear layer that yields one scalar per
    sample.

    Args:
        input_dim: Length of the concatenated flattened input. Presumably
            ``N * N + N * op_dim`` for ``N``-node graphs — TODO confirm
            against the caller.
        linear_hidden: Width of both hidden layers.
        dropout: Dropout probability applied after each hidden layer.
    """

    def __init__(self, input_dim, linear_hidden=512, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(input_dim, linear_hidden)
        self.fc2 = nn.Linear(linear_hidden, linear_hidden)
        self.fc3 = nn.Linear(linear_hidden, 1)

    def forward(self, inputs):
        """Predict one scalar per graph in the batch.

        Args:
            inputs: Mapping that must contain ``"adjacency"`` (presumably a
                ``(batch, N, N)`` tensor — TODO confirm) and ``"operations"``
                (a tensor whose first dimension is the batch).

        Returns:
            1-D tensor with one prediction per sample.
        """
        adj, ops = inputs["adjacency"], inputs["operations"]

        num_nodes = adj.size(1)
        # Add self-loops before row-normalizing. This assumes the incoming
        # adjacency has zeros on its diagonal; an existing 1 would become 2.
        adj_with_diag = normalize_adj(
            adj + torch.eye(num_nodes, device=adj.device)
        )

        # flatten(..., 1) keeps the batch axis and, unlike view(), also
        # accepts non-contiguous tensors.
        features = torch.cat(
            (torch.flatten(adj_with_diag, 1), torch.flatten(ops, 1)), dim=1
        )

        hidden = self.dropout(F.relu(self.fc1(features)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        return self.fc3(hidden).view(-1)
